{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sentence classification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial we will try to predict whether an SMS is a spam or not. To train our model, we will use the `SMSSpam` dataset. This dataset is unbalanced, there is only 13.4% spam. Let's look at the data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:52:50.005002Z",
     "iopub.status.busy": "2023-12-04T17:52:50.004369Z",
     "iopub.status.idle": "2023-12-04T17:52:50.415357Z",
     "shell.execute_reply": "2023-12-04T17:52:50.415077Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SMS Spam Collection dataset.\n",
       "\n",
       "The data contains 5,574 items and 1 feature (i.e. SMS body). Spam messages represent\n",
       "13.4% of the dataset. The goal is to predict whether an SMS is a spam or not.\n",
       "\n",
       "      Name  SMSSpam                                                                              \n",
       "      Task  Binary classification                                                                \n",
       "   Samples  5,574                                                                                \n",
       "  Features  1                                                                                    \n",
       "    Sparse  False                                                                                \n",
       "      Path  /Users/max/river_data/SMSSpam/SMSSpamCollection                                      \n",
       "       URL  https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip\n",
       "      Size  466.71 KB                                                                            \n",
       "Downloaded  True                                                                                 "
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from river import datasets\n",
    "\n",
    "datasets.SMSSpam()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:52:50.416917Z",
     "iopub.status.busy": "2023-12-04T17:52:50.416799Z",
     "iopub.status.idle": "2023-12-04T17:52:50.427410Z",
     "shell.execute_reply": "2023-12-04T17:52:50.427185Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'body': 'Go until jurong point, crazy.. Available only in bugis n great world '\n",
      "         'la e buffet... Cine there got amore wat...\\n'}\n",
      "Spam: False\n"
     ]
    }
   ],
   "source": [
    "from pprint import pprint\n",
    "\n",
    "X_y = datasets.SMSSpam()\n",
    "\n",
    "for x, y in X_y:\n",
    "    pprint(x)\n",
    "    print(f'Spam: {y}')\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's start by building a simple model like a Naive Bayes classifier. We will first preprocess the sentences with a TF-IDF transform that our model can consume. Then, we will measure the accuracy of our model with the AUC metric. This is the right metric to use when the classes are not balanced. In addition, the Naive Bayes models can perform very well on unbalanced datasets and can be used for both binary and multi-class classification problems."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:52:50.428788Z",
     "iopub.status.busy": "2023-12-04T17:52:50.428710Z",
     "iopub.status.idle": "2023-12-04T17:53:07.741457Z",
     "shell.execute_reply": "2023-12-04T17:53:07.741048Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ROCAUC: 93.00%"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from river import feature_extraction\n",
    "from river import naive_bayes\n",
    "from river import metrics\n",
    "\n",
    "X_y = datasets.SMSSpam()\n",
    "\n",
    "model = (\n",
    "    feature_extraction.TFIDF(on='body') | \n",
    "    naive_bayes.BernoulliNB(alpha=0)\n",
    ")\n",
    "\n",
    "metric = metrics.ROCAUC()\n",
    "cm = metrics.ConfusionMatrix()\n",
    "\n",
    "for x, y in X_y:\n",
    "    \n",
    "    y_pred = model.predict_one(x)\n",
    "    \n",
    "    if y_pred is not None:\n",
    "        metric.update(y_pred=y_pred, y_true=y)\n",
    "        cm.update(y_pred=y_pred, y_true=y)\n",
    "    \n",
    "    model.learn_one(x, y)\n",
    "    \n",
    "metric"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The confusion matrix:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:53:07.743587Z",
     "iopub.status.busy": "2023-12-04T17:53:07.743450Z",
     "iopub.status.idle": "2023-12-04T17:53:07.754659Z",
     "shell.execute_reply": "2023-12-04T17:53:07.754387Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "        False   True  \n",
       "False   4,809     17  \n",
       " True     102    645  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The results are quite good with this first model."
   ]
  },
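  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the confusion matrix above easier to read, here is a small sketch that derives precision and recall for the spam class by hand. It assumes that rows are true labels and columns are predictions, which is consistent with the 747 spam messages in the `True` row. The variable names are ours.\n",
    "\n",
    "```python\n",
    "# Counts copied from the confusion matrix above\n",
    "# (assumed layout: rows = true label, columns = predicted label)\n",
    "tn, fp = 4809, 17\n",
    "fn, tp = 102, 645\n",
    "\n",
    "precision = tp / (tp + fp)  # share of predicted spam that really is spam\n",
    "recall = tp / (tp + fn)     # share of actual spam that was caught\n",
    "\n",
    "print(f'precision={precision:.3f} recall={recall:.3f}')\n",
    "```\n",
    "\n",
    "With these counts, precision is about 0.974 and recall about 0.863: the first model rarely flags legitimate messages, but it misses roughly one spam message in seven."
   ]
  },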
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since we are working with an imbalanced dataset, we can use the `imblearn` module to rebalance the classes of our dataset. For more information about the `imblearn` module, you can find a dedicated tutorial [here](../imbalanced-learning)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:53:07.756292Z",
     "iopub.status.busy": "2023-12-04T17:53:07.756192Z",
     "iopub.status.idle": "2023-12-04T17:53:17.349404Z",
     "shell.execute_reply": "2023-12-04T17:53:17.348983Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ROCAUC: 94.61%"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from river import imblearn\n",
    "\n",
    "X_y = datasets.SMSSpam()\n",
    "\n",
    "model = (\n",
    "    feature_extraction.TFIDF(on='body') | \n",
    "    imblearn.RandomUnderSampler(\n",
    "        classifier=naive_bayes.BernoulliNB(alpha=0),\n",
    "        desired_dist={0: .5, 1: .5},\n",
    "        seed=42\n",
    "    )\n",
    ")\n",
    "\n",
    "metric = metrics.ROCAUC()\n",
    "cm = metrics.ConfusionMatrix()\n",
    "\n",
    "for x, y in X_y:\n",
    "    \n",
    "    y_pred = model.predict_one(x)\n",
    "    \n",
    "    if y_pred is not None:\n",
    "        metric.update(y_pred=y_pred, y_true=y)\n",
    "        cm.update(y_pred=y_pred, y_true=y)\n",
    "    \n",
    "    model.learn_one(x, y)\n",
    "    \n",
    "metric"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `imblearn` module improved our results. Not bad! We can visualize the pipeline to understand how the data is processed."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The confusion matrix:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:53:17.351734Z",
     "iopub.status.busy": "2023-12-04T17:53:17.351613Z",
     "iopub.status.idle": "2023-12-04T17:53:17.362178Z",
     "shell.execute_reply": "2023-12-04T17:53:17.361751Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "        False   True  \n",
       "False   4,570    255  \n",
       " True      41    706  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:53:17.364375Z",
     "iopub.status.busy": "2023-12-04T17:53:17.364251Z",
     "iopub.status.idle": "2023-12-04T17:53:17.377988Z",
     "shell.execute_reply": "2023-12-04T17:53:17.377664Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><div class=\"river-component river-pipeline\"><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">TFIDF</pre></summary><code class=\"river-estimator-params\">TFIDF (\n",
       "  normalize=True\n",
       "  on=\"body\"\n",
       "  strip_accents=True\n",
       "  lowercase=True\n",
       "  preprocessor=None\n",
       "  tokenizer=None\n",
       "  ngram_range=(1, 1)\n",
       ")\n",
       "</code></details><div class=\"river-component river-wrapper\"><details class=\"river-details\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">RandomUnderSampler</pre></summary><code class=\"river-estimator-params\">RandomUnderSampler (\n",
       "  classifier=BernoulliNB (\n",
       "    alpha=0\n",
       "    true_threshold=0.\n",
       "  )\n",
       "  desired_dist={0: 0.5, 1: 0.5}\n",
       "  seed=42\n",
       ")\n",
       "</code></details><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">BernoulliNB</pre></summary><code class=\"river-estimator-params\">BernoulliNB (\n",
       "  alpha=0\n",
       "  true_threshold=0.\n",
       ")\n",
       "</code></details></div></div><style scoped>\n",
       ".river-estimator {\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "    max-width: max-content;\n",
       "}\n",
       "\n",
       ".river-pipeline {\n",
       "    display: flex;\n",
       "    flex-direction: column;\n",
       "    align-items: center;\n",
       "    background: linear-gradient(#000, #000) no-repeat center / 1.5px 100%;\n",
       "}\n",
       "\n",
       ".river-union {\n",
       "    display: flex;\n",
       "    flex-direction: row;\n",
       "    align-items: center;\n",
       "    justify-content: center;\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "}\n",
       "\n",
       ".river-wrapper {\n",
       "    display: flex;\n",
       "    flex-direction: column;\n",
       "    align-items: center;\n",
       "    justify-content: center;\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "}\n",
       "\n",
       ".river-wrapper > .river-estimator {\n",
       "    margin-top: 1em;\n",
       "}\n",
       "\n",
       "/* Vertical spacing between steps */\n",
       "\n",
       ".river-component + .river-component {\n",
       "    margin-top: 2em;\n",
       "}\n",
       "\n",
       ".river-union > .river-estimator {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       ".river-union > .river-component {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       ".river-union > .pipeline {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       "/* Spacing within a union of estimators */\n",
       "\n",
       ".river-union > .river-component + .river-component {\n",
       "    margin-left: 1em;\n",
       "}\n",
       "\n",
       "/* Typography */\n",
       "\n",
       ".river-estimator-params {\n",
       "    display: block;\n",
       "    white-space: pre-wrap;\n",
       "    font-size: 110%;\n",
       "    margin-top: 1em;\n",
       "}\n",
       "\n",
       ".river-estimator > .river-estimator-params,\n",
       ".river-wrapper > .river-details > river-estimator-params {\n",
       "    background-color: white !important;\n",
       "}\n",
       "\n",
       ".river-wrapper > .river-details {\n",
       "    margin-bottom: 1em;\n",
       "}\n",
       "\n",
       ".river-estimator-name {\n",
       "    display: inline;\n",
       "    margin: 0;\n",
       "    font-size: 110%;\n",
       "}\n",
       "\n",
       "/* Toggle */\n",
       "\n",
       ".river-summary {\n",
       "    display: flex;\n",
       "    align-items:center;\n",
       "    cursor: pointer;\n",
       "}\n",
       "\n",
       ".river-summary > div {\n",
       "    width: 100%;\n",
       "}\n",
       "</style></div>"
      ],
      "text/plain": [
       "Pipeline (\n",
       "  TFIDF (\n",
       "    normalize=True\n",
       "    on=\"body\"\n",
       "    strip_accents=True\n",
       "    lowercase=True\n",
       "    preprocessor=None\n",
       "    tokenizer=None\n",
       "    ngram_range=(1, 1)\n",
       "  ),\n",
       "  RandomUnderSampler (\n",
       "    classifier=BernoulliNB (\n",
       "      alpha=0\n",
       "      true_threshold=0.\n",
       "    )\n",
       "    desired_dist={0: 0.5, 1: 0.5}\n",
       "    seed=42\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
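  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To give an intuition for what rebalancing does on a stream, here is a minimal pure-Python sketch of random under-sampling by rejection. It is an illustration under our own assumptions, not River's implementation: the helper `undersample` and the toy stream are hypothetical.\n",
    "\n",
    "```python\n",
    "import random\n",
    "from collections import Counter\n",
    "\n",
    "\n",
    "def undersample(stream, desired, seed=42):\n",
    "    # Hypothetical helper: keep each (x, y) with a probability chosen so\n",
    "    # that the kept samples approach the `desired` class distribution.\n",
    "    rng = random.Random(seed)\n",
    "    seen = Counter()\n",
    "    for x, y in stream:\n",
    "        seen[y] += 1\n",
    "        total = sum(seen.values())\n",
    "        # Ratio between the desired and the observed share of each class\n",
    "        ratios = {c: desired[c] / (n / total) for c, n in seen.items()}\n",
    "        # The most under-represented class is always kept\n",
    "        if rng.random() < ratios[y] / max(ratios.values()):\n",
    "            yield x, y\n",
    "\n",
    "\n",
    "# Toy stream: 13 positives out of every 100 samples\n",
    "stream = [(i, i % 100 < 13) for i in range(10_000)]\n",
    "kept = list(undersample(stream, desired={False: 0.5, True: 0.5}))\n",
    "share = sum(y for _, y in kept) / len(kept)\n",
    "print(f'kept {len(kept)} samples, {share:.0%} positive')\n",
    "```\n",
    "\n",
    "On this toy stream, every positive is kept while most negatives are dropped, so the kept samples end up close to the requested 50/50 split. The price is that the classifier learns from far fewer examples."
   ]
  },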
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's try to use logistic regression to classify messages. We will use different tips to make my model perform better. As in the previous example, we rebalance the classes of our dataset. The logistics regression will be fed from a TF-IDF."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:53:17.379539Z",
     "iopub.status.busy": "2023-12-04T17:53:17.379439Z",
     "iopub.status.idle": "2023-12-04T17:53:17.933593Z",
     "shell.execute_reply": "2023-12-04T17:53:17.933265Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ROCAUC: 93.80%"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from river import linear_model\n",
    "from river import optim\n",
    "from river import preprocessing\n",
    "\n",
    "X_y = datasets.SMSSpam()\n",
    "\n",
    "model = (\n",
    "    feature_extraction.TFIDF(on='body') | \n",
    "    preprocessing.Normalizer() | \n",
    "    imblearn.RandomUnderSampler(\n",
    "        classifier=linear_model.LogisticRegression(\n",
    "            optimizer=optim.SGD(.9), \n",
    "            loss=optim.losses.Log()\n",
    "        ),\n",
    "        desired_dist={0: .5, 1: .5},\n",
    "        seed=42\n",
    "    )\n",
    ")\n",
    "\n",
    "metric = metrics.ROCAUC()\n",
    "cm = metrics.ConfusionMatrix()\n",
    "\n",
    "for x, y in X_y:\n",
    "    \n",
    "    y_pred = model.predict_one(x)\n",
    "\n",
    "    metric.update(y_pred=y_pred, y_true=y)\n",
    "    cm.update(y_pred=y_pred, y_true=y)\n",
    "    \n",
    "    model.learn_one(x, y)\n",
    "    \n",
    "metric"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The confusion matrix:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:53:17.935653Z",
     "iopub.status.busy": "2023-12-04T17:53:17.935494Z",
     "iopub.status.idle": "2023-12-04T17:53:17.947268Z",
     "shell.execute_reply": "2023-12-04T17:53:17.946889Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "        False   True  \n",
       "False   4,584    243  \n",
       " True      55    692  "
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:53:17.949234Z",
     "iopub.status.busy": "2023-12-04T17:53:17.949106Z",
     "iopub.status.idle": "2023-12-04T17:53:17.960965Z",
     "shell.execute_reply": "2023-12-04T17:53:17.960655Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><div class=\"river-component river-pipeline\"><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">TFIDF</pre></summary><code class=\"river-estimator-params\">TFIDF (\n",
       "  normalize=True\n",
       "  on=\"body\"\n",
       "  strip_accents=True\n",
       "  lowercase=True\n",
       "  preprocessor=None\n",
       "  tokenizer=None\n",
       "  ngram_range=(1, 1)\n",
       ")\n",
       "</code></details><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">Normalizer</pre></summary><code class=\"river-estimator-params\">Normalizer (\n",
       "  order=2\n",
       ")\n",
       "</code></details><div class=\"river-component river-wrapper\"><details class=\"river-details\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">RandomUnderSampler</pre></summary><code class=\"river-estimator-params\">RandomUnderSampler (\n",
       "  classifier=LogisticRegression (\n",
       "    optimizer=SGD (\n",
       "      lr=Constant (\n",
       "        learning_rate=0.9\n",
       "      )\n",
       "    )\n",
       "    loss=Log (\n",
       "      weight_pos=1.\n",
       "      weight_neg=1.\n",
       "    )\n",
       "    l2=0.\n",
       "    l1=0.\n",
       "    intercept_init=0.\n",
       "    intercept_lr=Constant (\n",
       "      learning_rate=0.01\n",
       "    )\n",
       "    clip_gradient=1e+12\n",
       "    initializer=Zeros ()\n",
       "  )\n",
       "  desired_dist={0: 0.5, 1: 0.5}\n",
       "  seed=42\n",
       ")\n",
       "</code></details><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">LogisticRegression</pre></summary><code class=\"river-estimator-params\">LogisticRegression (\n",
       "  optimizer=SGD (\n",
       "    lr=Constant (\n",
       "      learning_rate=0.9\n",
       "    )\n",
       "  )\n",
       "  loss=Log (\n",
       "    weight_pos=1.\n",
       "    weight_neg=1.\n",
       "  )\n",
       "  l2=0.\n",
       "  l1=0.\n",
       "  intercept_init=0.\n",
       "  intercept_lr=Constant (\n",
       "    learning_rate=0.01\n",
       "  )\n",
       "  clip_gradient=1e+12\n",
       "  initializer=Zeros ()\n",
       ")\n",
       "</code></details></div></div><style scoped>\n",
       ".river-estimator {\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "    max-width: max-content;\n",
       "}\n",
       "\n",
       ".river-pipeline {\n",
       "    display: flex;\n",
       "    flex-direction: column;\n",
       "    align-items: center;\n",
       "    background: linear-gradient(#000, #000) no-repeat center / 1.5px 100%;\n",
       "}\n",
       "\n",
       ".river-union {\n",
       "    display: flex;\n",
       "    flex-direction: row;\n",
       "    align-items: center;\n",
       "    justify-content: center;\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "}\n",
       "\n",
       ".river-wrapper {\n",
       "    display: flex;\n",
       "    flex-direction: column;\n",
       "    align-items: center;\n",
       "    justify-content: center;\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "}\n",
       "\n",
       ".river-wrapper > .river-estimator {\n",
       "    margin-top: 1em;\n",
       "}\n",
       "\n",
       "/* Vertical spacing between steps */\n",
       "\n",
       ".river-component + .river-component {\n",
       "    margin-top: 2em;\n",
       "}\n",
       "\n",
       ".river-union > .river-estimator {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       ".river-union > .river-component {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       ".river-union > .pipeline {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       "/* Spacing within a union of estimators */\n",
       "\n",
       ".river-union > .river-component + .river-component {\n",
       "    margin-left: 1em;\n",
       "}\n",
       "\n",
       "/* Typography */\n",
       "\n",
       ".river-estimator-params {\n",
       "    display: block;\n",
       "    white-space: pre-wrap;\n",
       "    font-size: 110%;\n",
       "    margin-top: 1em;\n",
       "}\n",
       "\n",
       ".river-estimator > .river-estimator-params,\n",
       ".river-wrapper > .river-details > river-estimator-params {\n",
       "    background-color: white !important;\n",
       "}\n",
       "\n",
       ".river-wrapper > .river-details {\n",
       "    margin-bottom: 1em;\n",
       "}\n",
       "\n",
       ".river-estimator-name {\n",
       "    display: inline;\n",
       "    margin: 0;\n",
       "    font-size: 110%;\n",
       "}\n",
       "\n",
       "/* Toggle */\n",
       "\n",
       ".river-summary {\n",
       "    display: flex;\n",
       "    align-items:center;\n",
       "    cursor: pointer;\n",
       "}\n",
       "\n",
       ".river-summary > div {\n",
       "    width: 100%;\n",
       "}\n",
       "</style></div>"
      ],
      "text/plain": [
       "Pipeline (\n",
       "  TFIDF (\n",
       "    normalize=True\n",
       "    on=\"body\"\n",
       "    strip_accents=True\n",
       "    lowercase=True\n",
       "    preprocessor=None\n",
       "    tokenizer=None\n",
       "    ngram_range=(1, 1)\n",
       "  ),\n",
       "  Normalizer (\n",
       "    order=2\n",
       "  ),\n",
       "  RandomUnderSampler (\n",
       "    classifier=LogisticRegression (\n",
       "      optimizer=SGD (\n",
       "        lr=Constant (\n",
       "          learning_rate=0.9\n",
       "        )\n",
       "      )\n",
       "      loss=Log (\n",
       "        weight_pos=1.\n",
       "        weight_neg=1.\n",
       "      )\n",
       "      l2=0.\n",
       "      l1=0.\n",
       "      intercept_init=0.\n",
       "      intercept_lr=Constant (\n",
       "        learning_rate=0.01\n",
       "      )\n",
       "      clip_gradient=1e+12\n",
       "      initializer=Zeros ()\n",
       "    )\n",
       "    desired_dist={0: 0.5, 1: 0.5}\n",
       "    seed=42\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The results of the logistic regression are quite good but still inferior to the naive Bayes model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's try to use word embeddings to improve our logistic regression. Word embeddings allow you to represent a word as a vector. Embeddings are developed to build semantically rich vectors. For instance, the vector which represents the word **python** should be close to the vector which represents the word **programming**. We will use [spaCy](https://spacy.io/) to convert our sentence to vectors. spaCy converts a sentence to a vector by calculating the average of the embeddings of the words in the sentence."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can download pre-trained embeddings in many languages. We will use English pre-trained embeddings as our SMS are in English."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The command below allows you to download the pre-trained embeddings that spaCy makes available. More information about spaCy and its installation may be found here [here](https://spacy.io/usage)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```sh\n",
    "python -m spacy download en_core_web_sm\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here, we create a custom transformer to convert an input sentence to a dict of floats. We will integrate this transformer into our pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:53:17.963440Z",
     "iopub.status.busy": "2023-12-04T17:53:17.963311Z",
     "iopub.status.idle": "2023-12-04T17:53:18.493616Z",
     "shell.execute_reply": "2023-12-04T17:53:18.493319Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import spacy\n",
    "\n",
    "from river.base import Transformer\n",
    "\n",
    "class Embeddings(Transformer):\n",
    "    \"\"\"My custom transformer, word embedding using spaCy.\"\"\"\n",
    "    \n",
    "    def __init__(self, on: str):\n",
    "        self.on = on\n",
    "        self.embeddings = spacy.load('en_core_web_sm')\n",
    "        \n",
    "    def transform_one(self, x, y=None):\n",
    "        return {dimension: xi for dimension, xi in enumerate(self.embeddings(x[self.on]).vector)}"
   ]
  },
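  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea of turning a sentence into a single vector by averaging word vectors can be illustrated without spaCy. The sketch below uses a made-up two-dimensional vocabulary (`toy_vectors` and `embed` are hypothetical, not real embeddings or a spaCy API) and returns a dict of floats, which is the format our transformer emits.\n",
    "\n",
    "```python\n",
    "# Hypothetical 2-d word vectors, for illustration only\n",
    "toy_vectors = {'free': [1.0, 0.0], 'prize': [0.0, 1.0]}\n",
    "\n",
    "\n",
    "def embed(sentence):\n",
    "    # Average the vectors of the known words, dimension by dimension\n",
    "    words = sentence.lower().split()\n",
    "    vectors = [toy_vectors[w] for w in words if w in toy_vectors]\n",
    "    if not vectors:\n",
    "        return {}\n",
    "    dim = len(vectors[0])\n",
    "    return {i: sum(v[i] for v in vectors) / len(vectors) for i in range(dim)}\n",
    "\n",
    "\n",
    "features = embed('Free prize')\n",
    "print(features)  # {0: 0.5, 1: 0.5}\n",
    "```\n",
    "\n",
    "Averaging discards word order, which is one reason why sentence vectors built this way can lose information that TF-IDF features keep."
   ]
  },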
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's train our logistic regression:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:53:18.495764Z",
     "iopub.status.busy": "2023-12-04T17:53:18.495554Z",
     "iopub.status.idle": "2023-12-04T17:54:14.281153Z",
     "shell.execute_reply": "2023-12-04T17:54:14.280920Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ROCAUC: 91.31%"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_y = datasets.SMSSpam()\n",
    "\n",
    "model = (\n",
    "    Embeddings(on='body') | \n",
    "    preprocessing.Normalizer() |\n",
    "    imblearn.RandomOverSampler(\n",
    "        classifier=linear_model.LogisticRegression(\n",
    "            optimizer=optim.SGD(.5), \n",
    "            loss=optim.losses.Log()\n",
    "        ),\n",
    "        desired_dist={0: .5, 1: .5},\n",
    "        seed=42\n",
    "    )\n",
    ")\n",
    "\n",
    "metric = metrics.ROCAUC()\n",
    "cm = metrics.ConfusionMatrix()\n",
    "\n",
    "for x, y in X_y:\n",
    "    \n",
    "    y_pred = model.predict_one(x)\n",
    "    \n",
    "    metric.update(y_pred=y_pred, y_true=y)\n",
    "    cm.update(y_pred=y_pred, y_true=y)\n",
    "    \n",
    "    model.learn_one(x, y)\n",
    "\n",
    "metric"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The confusion matrix:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:54:14.282735Z",
     "iopub.status.busy": "2023-12-04T17:54:14.282623Z",
     "iopub.status.idle": "2023-12-04T17:54:14.294872Z",
     "shell.execute_reply": "2023-12-04T17:54:14.294638Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "        False   True  \n",
       "False   4,537    290  \n",
       " True      85    662  "
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:54:14.296162Z",
     "iopub.status.busy": "2023-12-04T17:54:14.296083Z",
     "iopub.status.idle": "2023-12-04T17:54:14.307160Z",
     "shell.execute_reply": "2023-12-04T17:54:14.306923Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><div class=\"river-component river-pipeline\"><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">Embeddings</pre></summary><code class=\"river-estimator-params\">Embeddings (\n",
       "  on=\"body\"\n",
       ")\n",
       "</code></details><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">Normalizer</pre></summary><code class=\"river-estimator-params\">Normalizer (\n",
       "  order=2\n",
       ")\n",
       "</code></details><div class=\"river-component river-wrapper\"><details class=\"river-details\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">RandomOverSampler</pre></summary><code class=\"river-estimator-params\">RandomOverSampler (\n",
       "  classifier=LogisticRegression (\n",
       "    optimizer=SGD (\n",
       "      lr=Constant (\n",
       "        learning_rate=0.5\n",
       "      )\n",
       "    )\n",
       "    loss=Log (\n",
       "      weight_pos=1.\n",
       "      weight_neg=1.\n",
       "    )\n",
       "    l2=0.\n",
       "    l1=0.\n",
       "    intercept_init=0.\n",
       "    intercept_lr=Constant (\n",
       "      learning_rate=0.01\n",
       "    )\n",
       "    clip_gradient=1e+12\n",
       "    initializer=Zeros ()\n",
       "  )\n",
       "  desired_dist={0: 0.5, 1: 0.5}\n",
       "  seed=42\n",
       ")\n",
       "</code></details><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">LogisticRegression</pre></summary><code class=\"river-estimator-params\">LogisticRegression (\n",
       "  optimizer=SGD (\n",
       "    lr=Constant (\n",
       "      learning_rate=0.5\n",
       "    )\n",
       "  )\n",
       "  loss=Log (\n",
       "    weight_pos=1.\n",
       "    weight_neg=1.\n",
       "  )\n",
       "  l2=0.\n",
       "  l1=0.\n",
       "  intercept_init=0.\n",
       "  intercept_lr=Constant (\n",
       "    learning_rate=0.01\n",
       "  )\n",
       "  clip_gradient=1e+12\n",
       "  initializer=Zeros ()\n",
       ")\n",
       "</code></details></div></div><style scoped>\n",
       ".river-estimator {\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "    max-width: max-content;\n",
       "}\n",
       "\n",
       ".river-pipeline {\n",
       "    display: flex;\n",
       "    flex-direction: column;\n",
       "    align-items: center;\n",
       "    background: linear-gradient(#000, #000) no-repeat center / 1.5px 100%;\n",
       "}\n",
       "\n",
       ".river-union {\n",
       "    display: flex;\n",
       "    flex-direction: row;\n",
       "    align-items: center;\n",
       "    justify-content: center;\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "}\n",
       "\n",
       ".river-wrapper {\n",
       "    display: flex;\n",
       "    flex-direction: column;\n",
       "    align-items: center;\n",
       "    justify-content: center;\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "}\n",
       "\n",
       ".river-wrapper > .river-estimator {\n",
       "    margin-top: 1em;\n",
       "}\n",
       "\n",
       "/* Vertical spacing between steps */\n",
       "\n",
       ".river-component + .river-component {\n",
       "    margin-top: 2em;\n",
       "}\n",
       "\n",
       ".river-union > .river-estimator {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       ".river-union > .river-component {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       ".river-union > .pipeline {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       "/* Spacing within a union of estimators */\n",
       "\n",
       ".river-union > .river-component + .river-component {\n",
       "    margin-left: 1em;\n",
       "}\n",
       "\n",
       "/* Typography */\n",
       "\n",
       ".river-estimator-params {\n",
       "    display: block;\n",
       "    white-space: pre-wrap;\n",
       "    font-size: 110%;\n",
       "    margin-top: 1em;\n",
       "}\n",
       "\n",
       ".river-estimator > .river-estimator-params,\n",
       ".river-wrapper > .river-details > river-estimator-params {\n",
       "    background-color: white !important;\n",
       "}\n",
       "\n",
       ".river-wrapper > .river-details {\n",
       "    margin-bottom: 1em;\n",
       "}\n",
       "\n",
       ".river-estimator-name {\n",
       "    display: inline;\n",
       "    margin: 0;\n",
       "    font-size: 110%;\n",
       "}\n",
       "\n",
       "/* Toggle */\n",
       "\n",
       ".river-summary {\n",
       "    display: flex;\n",
       "    align-items:center;\n",
       "    cursor: pointer;\n",
       "}\n",
       "\n",
       ".river-summary > div {\n",
       "    width: 100%;\n",
       "}\n",
       "</style></div>"
      ],
      "text/plain": [
       "Pipeline (\n",
       "  Embeddings (\n",
       "    on=\"body\"\n",
       "  ),\n",
       "  Normalizer (\n",
       "    order=2\n",
       "  ),\n",
       "  RandomOverSampler (\n",
       "    classifier=LogisticRegression (\n",
       "      optimizer=SGD (\n",
       "        lr=Constant (\n",
       "          learning_rate=0.5\n",
       "        )\n",
       "      )\n",
       "      loss=Log (\n",
       "        weight_pos=1.\n",
       "        weight_neg=1.\n",
       "      )\n",
       "      l2=0.\n",
       "      l1=0.\n",
       "      intercept_init=0.\n",
       "      intercept_lr=Constant (\n",
       "        learning_rate=0.01\n",
       "      )\n",
       "      clip_gradient=1e+12\n",
       "      initializer=Zeros ()\n",
       "    )\n",
       "    desired_dist={0: 0.5, 1: 0.5}\n",
       "    seed=42\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The results of the logistic regression using spaCy embeddings are lower than those obtained with TF-IDF values. We could surely improve the results by cleaning up the text. We could also use embeddings more suited to our dataset. However, on this problem, the logistic regression is not better than the Naive Bayes model. No free lunch today."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  },
  "vscode": {
   "interpreter": {
    "hash": "14b46bd212fa4dd89e3980db6ba7efbb9fe535833e1e483b914b71733e0a56d2"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
